{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sklearn compatible Grid Search for regression\n",
    "\n",
    "Grid search is an in-processing technique that can be used for fair classification or fair regression. For classification it reduces fair classification to a sequence of cost-sensitive classification problems, returning the deterministic classifier with the lowest empirical error subject to fair classification constraints among\n",
    "the candidates searched. For regression it uses the same priniciple to return a deterministic regressor with the lowest empirical error subject to the constraint of bounded group loss. The code for grid search wraps the source class `fairlearn.reductions.GridSearch` available in the https://github.com/fairlearn/fairlearn library, licensed under the MIT Licencse, Copyright Microsoft Corporation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn import preprocessing\n",
    "\n",
    "from aif360.sklearn.inprocessing import GridSearchReduction\n",
    "\n",
    "from aif360.sklearn.datasets import fetch_lawschool_gpa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame contains protected attribute values per sample. Datasets may also load a `sample_weight` object to be used with certain algorithms/metrics. All of this makes it so that aif360 is compatible with scikit-learn objects.\n",
    "\n",
    "For example, we can easily load the law school gpa dataset from tempeh with the following line:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>lsat</th>\n",
       "      <th>ugpa</th>\n",
       "      <th>race</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>race</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <th>black</th>\n",
       "      <td>38.0</td>\n",
       "      <td>3.3</td>\n",
       "      <td>black</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <th>white</th>\n",
       "      <td>34.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>white</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <th>white</th>\n",
       "      <td>34.0</td>\n",
       "      <td>3.9</td>\n",
       "      <td>white</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <th>white</th>\n",
       "      <td>45.0</td>\n",
       "      <td>3.3</td>\n",
       "      <td>white</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <th>white</th>\n",
       "      <td>39.0</td>\n",
       "      <td>2.5</td>\n",
       "      <td>white</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         lsat  ugpa   race\n",
       "  race                    \n",
       "0 black  38.0   3.3  black\n",
       "1 white  34.0   4.0  white\n",
       "2 white  34.0   3.9  white\n",
       "3 white  45.0   3.3  white\n",
       "4 white  39.0   2.5  white"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train, y_train = fetch_lawschool_gpa(subset=\"train\")\n",
    "X_test, y_test = fetch_lawschool_gpa(subset=\"test\")\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then map the protected attributes to integers,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train.index = pd.MultiIndex.from_arrays(X_train.index.codes, names=X_train.index.names)\n",
    "X_test.index = pd.MultiIndex.from_arrays(X_test.index.codes, names=X_test.index.names)\n",
    "y_train.index = pd.MultiIndex.from_arrays(y_train.index.codes, names=y_train.index.names)\n",
    "y_test.index = pd.MultiIndex.from_arrays(y_test.index.codes, names=y_test.index.names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use Pandas for one-hot encoding for easy reference to columns associated with protected attributes, information necessary for grid search reduction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>lsat</th>\n",
       "      <th>ugpa</th>\n",
       "      <th>race_black</th>\n",
       "      <th>race_white</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>race</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>38.0</td>\n",
       "      <td>3.3</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <th>1</th>\n",
       "      <td>34.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <th>1</th>\n",
       "      <td>34.0</td>\n",
       "      <td>3.9</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <th>1</th>\n",
       "      <td>45.0</td>\n",
       "      <td>3.3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <th>1</th>\n",
       "      <td>39.0</td>\n",
       "      <td>2.5</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        lsat  ugpa  race_black  race_white\n",
       "  race                                    \n",
       "0 0     38.0   3.3           1           0\n",
       "1 1     34.0   4.0           0           1\n",
       "2 1     34.0   3.9           0           1\n",
       "3 1     45.0   3.3           0           1\n",
       "4 1     39.0   2.5           0           1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train, X_test = pd.get_dummies(X_train), pd.get_dummies(X_test)\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We normalize the continuous values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>lsat</th>\n",
       "      <th>ugpa</th>\n",
       "      <th>race_black</th>\n",
       "      <th>race_white</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>race</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <th>0</th>\n",
       "      <td>0.729730</td>\n",
       "      <td>0.825</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <th>1</th>\n",
       "      <td>0.621622</td>\n",
       "      <td>1.000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <th>1</th>\n",
       "      <td>0.621622</td>\n",
       "      <td>0.975</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <th>1</th>\n",
       "      <td>0.918919</td>\n",
       "      <td>0.825</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <th>1</th>\n",
       "      <td>0.756757</td>\n",
       "      <td>0.625</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            lsat   ugpa  race_black  race_white\n",
       "  race                                         \n",
       "0 0     0.729730  0.825         1.0         0.0\n",
       "1 1     0.621622  1.000         0.0         1.0\n",
       "2 1     0.621622  0.975         0.0         1.0\n",
       "3 1     0.918919  0.825         0.0         1.0\n",
       "4 1     0.756757  0.625         0.0         1.0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "min_max_scaler = preprocessing.MinMaxScaler()\n",
    "X_train = pd.DataFrame(min_max_scaler.fit_transform(X_train.values),columns=list(X_train),index=X_train.index)\n",
    "X_test = pd.DataFrame(min_max_scaler.transform(X_test.values),columns=list(X_test),index=X_test.index)\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "min_max_scaler = preprocessing.MinMaxScaler()\n",
    "y_train = pd.Series(min_max_scaler.fit_transform(y_train.values.reshape(-1, 1)).flatten(),index=y_train.index)\n",
    "y_test = pd.Series(min_max_scaler.transform(y_test.values.reshape(-1, 1)).flatten(),index=y_test.index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The protected attribute information is also replicated in the labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "   race\n",
       "0  0       0.488636\n",
       "1  1       0.688131\n",
       "2  1       0.398990\n",
       "3  1       0.758838\n",
       "4  1       0.482323\n",
       "dtype: float64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data. We drop the protective attribule columns so that they are not used in the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "prot_attr_cols = [col for col in list(X_train) if \"race\" in col]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.09344477678851784\n"
     ]
    }
   ],
   "source": [
    "lr = LinearRegression().fit(X_train.drop(prot_attr_cols,axis=1), y_train)\n",
    "y_pred = lr.predict(X_test.drop(prot_attr_cols, axis=1))\n",
    "lr_mae = mean_absolute_error(y_test, y_pred)\n",
    "print(lr_mae)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can assess how the mean absolute error differs across groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "White: 0.09151357295567962\n"
     ]
    }
   ],
   "source": [
    "X_test_white = X_test.iloc[X_test.index.get_level_values('race') == 1]\n",
    "y_test_white = y_test.iloc[y_test.index.get_level_values('race') == 1]\n",
    "\n",
    "y_pred_white = lr.predict(X_test_white.drop(prot_attr_cols, axis=1))\n",
    "\n",
    "lr_mae_w = mean_absolute_error(y_test_white, y_pred_white)\n",
    "print(\"White:\", lr_mae_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Black: 0.11726179331646831\n"
     ]
    }
   ],
   "source": [
    "X_test_black = X_test.iloc[X_test.index.get_level_values('race') == 0]\n",
    "y_test_black = y_test.iloc[y_test.index.get_level_values('race') == 0]\n",
    "\n",
    "y_pred_black = lr.predict(X_test_black.drop(prot_attr_cols, axis=1))\n",
    "\n",
    "lr_mae_b = mean_absolute_error(y_test_black, y_pred_black)\n",
    "print(\"Black:\", lr_mae_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean absolute error difference across groups: 0.025748220360788693\n"
     ]
    }
   ],
   "source": [
    "print(\"Mean absolute error difference across groups:\", lr_mae_b-lr_mae_w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Grid Search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose a base model for the candidate regressors. Base models should implement a fit method that can take a sample weight as input. For details refer to the docs. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator = LinearRegression()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Search for the best regressor and observe mean absolute error. Grid search for regression uses \"GroupLoss\" to specify using bounded group loss for its constraints. Accordingly we need to specify a loss function, like \"Absolute.\" Other options include \"Square\" and \"ZeroOne.\" When the loss is \"Absolute\" or \"Square\" we also specify the expected range of the y values in min_val and max_val. For details on the implementation of these loss function see the fairlearn library here https://github.com/fairlearn/fairlearn/blob/master/fairlearn/reductions/_moments/bounded_group_loss.py."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.09624645677710374\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(0) #need for reproducibility\n",
    "grid_search_red = GridSearchReduction(prot_attr=prot_attr_cols, \n",
    "                                      estimator=estimator, \n",
    "                                      constraints=\"GroupLoss\",\n",
    "                                      loss=\"Absolute\",\n",
    "                                      min_val=0,\n",
    "                                      max_val=1,\n",
    "                                      grid_size=10,\n",
    "                                      drop_prot_attr=True)\n",
    "grid_search_red.fit(X_train, y_train)\n",
    "gs_pred = grid_search_red.predict(X_test)\n",
    "gs_mae = mean_absolute_error(y_test, gs_pred)\n",
    "print(gs_mae)\n",
    "\n",
    "#Check if mean absolute error is comparable\n",
    "assert abs(gs_mae-lr_mae)<0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "White: 0.09566668133321606\n"
     ]
    }
   ],
   "source": [
    "gs_mae_w = mean_absolute_error(y_test_white, grid_search_red.predict(X_test_white))\n",
    "print(\"White:\", gs_mae_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Black: 0.1033966711122104\n"
     ]
    }
   ],
   "source": [
    "gs_mae_b = mean_absolute_error(y_test_black, grid_search_red.predict(X_test_black))\n",
    "print(\"Black:\", gs_mae_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean absolute error difference across groups: 0.007729989778994348\n"
     ]
    }
   ],
   "source": [
    "print(\"Mean absolute error difference across groups:\", gs_mae_b-gs_mae_w)\n",
    "\n",
    "#Check if difference decreased\n",
    "assert (gs_mae_b-gs_mae_w)<(lr_mae_b-lr_mae_w)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}